import importlib
from utils import Timer


class MLFlow:
    """Optional, fail-soft wrapper around the ``mlflow`` module.

    When enabled and ``mlflow`` is importable, a small whitelist of mlflow
    functions is forwarded to the module. Otherwise those names resolve to
    no-op callables, so calling code can log unconditionally.
    """

    def __init__(self, log_dir, logger, enabled):
        """
        Args:
            log_dir: Accepted for interface parity with ``TensorboardWriter``;
                not used by this class.
            logger: Object with a ``warning`` method, used when mlflow is
                requested but cannot be imported.
            enabled: If falsy, mlflow is never imported and every forwarded
                call is a no-op.
        """
        self.mlflow = None

        if enabled:
            # Retrieve visualization writer.
            try:
                self.mlflow = importlib.import_module("mlflow")
            except ImportError:
                message = "Warning: visualization (mlflow) is configured to use, but currently not installed on " \
                          "this machine. Please install mlflow with 'pip install mlflow' or turn off the option in " \
                          "the 'config.json' file."
                logger.warning(message)

        self.step = 0
        # Suffix appended to tags as '<tag>/<mode>'; an empty mode adds no suffix.
        self.mode = ''

        # mlflow functions called as (key, value, ...); the key receives the mode suffix.
        self.mlflow_ftns_with_tag_and_value = {
            'log_param', 'log_metric'
        }
        # mlflow functions forwarded with their arguments unchanged.
        self.mlflow_ftns = {
            'start_run'
        }
        # Members of mlflow_ftns_with_tag_and_value whose key must NOT get the mode suffix.
        self.tag_mode_exceptions = set()

    def __getattr__(self, name):
        """Resolve whitelisted mlflow function names to forwarding callables.

        Only invoked when normal attribute lookup fails.

        Returns:
            For names in ``mlflow_ftns_with_tag_and_value``: a callable
            ``(tag, data, *args, **kwargs)`` that appends ``'/<mode>'`` to
            ``tag`` (when ``mode`` is non-empty and the name is not in
            ``tag_mode_exceptions``) and forwards to mlflow.
            For names in ``mlflow_ftns``: a callable forwarding all arguments.
            Either callable does nothing if mlflow is unavailable.

        Raises:
            AttributeError: for any other name.
        """
        # copy/pickle probe dunders on instances whose __init__ never ran;
        # never forward those.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        # Read state via __dict__ so a not-yet-set attribute cannot recurse
        # back into __getattr__.
        state = self.__dict__

        if name in state.get('mlflow_ftns_with_tag_and_value', ()):
            add_data = getattr(state.get('mlflow'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if self.mode and name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, *args, **kwargs)

            return wrapper

        if name in state.get('mlflow_ftns', ()):
            add_data = getattr(state.get('mlflow'), name, None)

            def wrapper(*args, **kwargs):
                if add_data is not None:
                    add_data(*args, **kwargs)

            return wrapper

        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))


class TensorboardWriter:
    """Optional TensorBoard writer that adds step and mode to logged values.

    Tries ``torch.utils.tensorboard`` first, then ``tensorboardX``; the first
    backend whose ``SummaryWriter`` can be created is used. Calls to the
    whitelisted ``add_*`` methods are forwarded to that writer with the
    current ``step`` inserted as the third positional argument. Tags are
    suffixed with ``'/<mode>'`` except for methods in ``tag_mode_exceptions``.
    When disabled, or when no backend is importable, those calls are no-ops.
    """

    def __init__(self, log_dir, logger, enabled):
        """
        Args:
            log_dir: Directory passed (as ``str``) to ``SummaryWriter``.
            logger: Object with a ``warning`` method, used when no backend
                can be imported.
            enabled: If falsy, no backend is imported and logging is a no-op.
        """
        self.writer = None
        # Name of the backend module actually in use ("" if none).
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer: first importable backend wins.
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                except ImportError:
                    continue
                # Record the backend that succeeded before leaving the loop.
                self.selected_module = module
                break
            else:
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
                    "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
                    "the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''

        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        # Methods whose tag is passed through without the '/<mode>' suffix.
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}

        self.timer = Timer()

    def set_step(self, step, mode='train'):
        """Set the global step and mode used for subsequent add_* calls.

        On ``step == 0`` the timer is reset. Otherwise the elapsed time from
        ``self.timer.check()`` is logged as ``steps_per_sec``, skipped when
        it is not positive so that a coarse clock cannot cause a division
        by zero.
        """
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer.reset()
        else:
            duration = self.timer.check()
            if duration > 0:
                self.add_scalar('steps_per_sec', 1 / duration)

    def __getattr__(self, name):
        """Resolve whitelisted ``add_*`` names to forwarding callables.

        Only invoked when normal attribute lookup fails.

        Returns:
            For names in ``tb_writer_ftns``: a callable
            ``(tag, data, *args, **kwargs)`` that calls the writer's method as
            ``(tag[/mode], data, step, *args, **kwargs)``, or does nothing if
            no writer is available.

        Raises:
            AttributeError: for any other name.
        """
        # copy/pickle probe dunders on instances whose __init__ never ran;
        # never forward those.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        # Read state via __dict__ so a not-yet-set attribute cannot recurse
        # back into __getattr__.
        state = self.__dict__

        if name in state.get('tb_writer_ftns', ()):
            add_data = getattr(state.get('writer'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper

        raise AttributeError("type object '{}' has no attribute '{}'".format(state.get('selected_module', ''), name))
